from __future__ import division, print_function, absolute_import

import os
import pickle

import numpy as np
from CDTools.tools.cmath import *
import torch as t
from RIP_tools import *

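#
# For each resolution ratio in a sweep, this script lights up one object
# pixel at a time, simulates the exit wave produced by a BLR probe, and
# records the total exit-wave intensity for that pixel (a uniformity trial).
# The per-pixel maps, cropped to the central region, are saved to report in
# the supplement.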

# Maximum k passed to generate_blr_probe; also sets the field size below.
probe_maxk = 128

# Run on the first CUDA device when one is available; set GPU = False to
# force CPU execution. (Hard-coding True crashes on machines without CUDA.)
GPU = t.cuda.is_available()

# Side length (pixels) of the square probe field: twice probe_maxk.
field_size = int(np.ceil(probe_maxk)) * 2
# This sets the focal spot size of our BLR probe (40% of the field size,
# rounded down to an integer).
probe_fov_radius = (field_size * 4) // 10

# Fraction of the object width excluded as a border on each side when
# comparing, so the comparison covers the central half of the image.
error_check_padding = 1 / 4

# Beamstop factor passed to generate_blr_probe.
bs_factor = 0.5

# Resolution ratios swept over: 20 evenly spaced values from 0.1 to 2.
resolution_ratios = np.linspace(0.1, 2, 20)
# For every resolution ratio, simulate a single lit object pixel at each
# position in the central region and record the total exit-wave intensity.
# full_pixel_intensity_stack[i] is a list (one entry per repeat) of 2D
# intensity maps, cropped to the central region, for resolution_ratios[i].
full_pixel_intensity_stack = []
for resolution_ratio in resolution_ratios:
    print('resolution ratio', resolution_ratio)
    # Side length (pixels) of the simulated object for this ratio
    resolution = int(probe_maxk * resolution_ratio) * 2
    pixel_intensity_stack = []
    # Repeats scale as 1 / resolution_ratio**2 (10 at the largest ratio),
    # which keeps the total number of simulated pixels per ratio roughly
    # constant.
    nrepeats = int(10 * np.max(resolution_ratios)**2 / resolution_ratio**2)
    # Border (pixels) excluded on each side, derived from the shared
    # constants instead of hard-coding 128 and 0.25.
    pad = int(probe_maxk * 2 * error_check_padding * resolution_ratio)
    # Simulate exactly the rows/columns that are kept by the crop below:
    # pad <= index < resolution - pad.
    kept_indices = range(pad, resolution - pad)

    for repeat in range(nrepeats):
        print('repeat', repeat, 'of', nrepeats, end='\r')
        # This generates an ideal BLR probe with the given parameters
        nearfield = generate_blr_probe([field_size, field_size],
                                       probe_fov_radius,
                                       probe_maxk,
                                       beamstop_factor=bs_factor)
        pixel_intensities = np.zeros((resolution, resolution))
        original_image = np.zeros((resolution, resolution), dtype=np.complex64)
        original_image = complex_to_torch(original_image).to(dtype=t.float32)
        # All this does is upsample the probe in real space to avoid
        # aliasing in the probe-object interaction
        simulation_nearfield = expand_probe(nearfield, original_image)

        if GPU:
            simulation_nearfield = simulation_nearfield.to(device='cuda:0')
            original_image = original_image.to(device='cuda:0')

        # Flat (resolution**2, 2) view sharing storage with original_image,
        # so writes through it edit the image in place. Taken after the
        # device move so it aliases the tensor actually simulated. Channel 0
        # is presumably the real part -- TODO confirm against complex_to_torch.
        flat_image = original_image.view(resolution**2, 2)
        last_pixel = None
        for row in kept_indices:
            for col in kept_indices:
                pixel = row * resolution + col
                # Turn off the previously lit pixel and light this one, so
                # exactly one pixel is lit per simulation
                if last_pixel is not None:
                    flat_image[last_pixel, 0] = 0
                flat_image[pixel, 0] = 1
                last_pixel = pixel

                # Now we perform the simulation
                exit_wave = interact(original_image, simulation_nearfield)
                intensities = cabssq(exit_wave)
                pixel_intensities[row, col] = np.float32(
                    t.sum(intensities).cpu().numpy())

        # Only save the central region. An explicit end bound is used so
        # that pad == 0 keeps the whole image rather than an empty slice.
        pixel_intensity_stack.append(
            pixel_intensities[pad:resolution - pad, pad:resolution - pad])
    # Terminate the '\r'-overwritten progress line
    print()
    full_pixel_intensity_stack.append(pixel_intensity_stack)
    
# Bundle the sweep parameters with the measured intensity maps so the
# pickle is self-describing.
results = {'resolution ratios': resolution_ratios,
           'beamstop_factor': bs_factor,
           'intensities': full_pixel_intensity_stack}
output_path = 'data/uniformity trial 2.pickle'
# Create the output directory if it is missing, so a long sweep is not lost
# to an IOError at the very last step.
output_dir = os.path.dirname(output_path)
if output_dir and not os.path.isdir(output_dir):
    os.makedirs(output_dir)
with open(output_path, 'wb') as f:
    pickle.dump(results, f)
